"""
strategy file ,include all strategy nothing else
Version = 1.0
"""

import time
import math
import string
from random import randint, choice
import logging
import sys
import os
import json
import platform
from datetime import datetime, timedelta
import threading
import copy
import sqlite3
from tkinter.constants import S

if 'Linux' not in platform.platform():
    sys.path.append('D:\\machine_learning\\')
    sys.path.append('C:\\Users\\51735\\machine_learning\\')
    from backtest_optimize.tools import *
    # from new_neural_net_work import *

from pyalgotrade import strategy
from pyalgotrade.technical import ma
from pyalgotrade.technical import cross
from pyalgotrade import broker
from pyalgotrade.broker.backtesting import Broker, TradePercentage
from pyalgotrade.broker import fillstrategy

import talib
import pandas as pd
import numpy as np
import tushare as ts

from RealModule import KQ2ts
from enum_var import Action

pd.set_option('display.max_columns', None)
pd.set_option('display.max_rows', None)
pd.set_option('max_colwidth', 100)

# 全局变量定义
logger = logging.getLogger('Yhlz.Strategy')
THREAD_BUFFER = {}  # 线程缓冲区。


class YhlzStreategy(strategy.BacktestingStrategy):
    """
    重写策略类，给每个策略都加上可以检查是否移仓的方法
    """
    realTrade = False
    realBroker = ''

    def __init__(self, barFeed, cash=1000000):

        if self.realTrade:
            super().__init__(barFeed, self.realBroker)
        else:
            temp = Broker(cash, barFeed)
            temp.setAllowNegativeCash(True)
            fillstg = fillstrategy.DefaultStrategy(None)  # 设置一个不检查有成交量的fill strategy
            temp.setFillStrategy(fillstg)
            super().__init__(barFeed, temp)
            self.tech = {}

    def checkTransPosition(self):
        """
        用来检查是否这个策略需要移仓
        :return:
        """
        pass

    def transInstrument(self, instrument):
        """
        接受一个品种代码输入，输出一个对应的可以交易的代码，比如有时候策略计算的是主力合约的或者是指数的数据，但是下单却下单到别
        的地方，也可以在各个策略中继承然后单独实现这个功能，和检查换合约一样。
        :return:
        """
        return instrument

    def setRealTrade(self, realBroker, realTrade=True):
        """
        设置是否实盘交易。
        :param realBroker: 设置实盘交易的broker
        :param realTrade: 为True则是实盘交易
        :return:
        """
        self.realTrade = realTrade
        self._setBroker(realBroker)

    def on_order_finished(self, order):
        """订单完成的时候会被broker调用，这里只是写出来，
        用来继承，免得被调用的时候说没有这个方法"""
        pass

    def stop(self):
        # use for write info to file when abnormal or normal exit 
        pass 


class SMACrossOver(YhlzStreategy):

    def __init__(self, feed, instrument, context, dictOfDataDf):
        self.logger = logging.getLogger(f'Yhlz.strategy.{self.__class__.__name__}')
        super(SMACrossOver, self).__init__(feed)
        temp = self.getBroker()
        temp.setCommission(TradePercentage(0.0001))  # 单独设置手续费。
        if isinstance(instrument, list):
            self.__instrument = instrument[0]
        elif isinstance(instrument, str):
            self.__instrument = instrument
        else:
            raise Exception('不合法的instrument！')

        self.__position = None
        # self.prices = feed[self.__instrument].getPriceDataSeries()
        self.unsend_order = []
        # 这个策略只有一个品种，所以pop出来必然是那个。pop后这个键值对就不存在，不能取两次
        if self.realTrade:
            #  实盘的时候需要区分数据的频率 context.strategys['SMACrossOver'][0][1] 是从strategytorun里面获取了这个合约的运行频率。
            self.prices = feed[self.__instrument + context.strategys['SMACrossOver'][0][1]].getPriceDataSeries()
            self.logger.debug('实盘')
            # self.sma = talib.SMA(self.prices.values, 10)
            # self.sma1 = talib.SMA(self.prices.values, 20)
        else:
            self.prices = feed[self.__instrument].getPriceDataSeries()
            self.logger.debug('回测')
            length = len(dictOfDataDf[list(dictOfDataDf.keys())[0]])  # 拿出第一个df的长度
            self.sma = ma.SMA(self.prices, 25, maxLen=length)
            self.sma1 = ma.SMA(self.prices, 30, maxLen=length)
            self.tech = {self.__instrument: {'sma short': self.sma, 'sma long': self.sma1}}

    def getSMA(self):
        return self.sma

    def transInstrument(self, instrument):
        """
        将指数合约转换为主力合约
        :param instrument:
        :return:
        """
        if self.realTrade:
            if 'KQ.i' in instrument:  # 指数合约。
                return self.getBroker().allTick[instrument.replace('KQ.i', 'KQ.m')].underlying_symbol
            elif 'KQ.m' in instrument:  # 主力合约
                return self.getBroker().allTick[instrument].underlying_symbol
            else:  # 真实合约。
                return instrument

        else:
            return instrument

    def onBars(self, bars):
        # logger.debug('onbars')
        quantity = int(self.getBroker().getEquity()/10000)
        if self.realTrade:
            self.sma = talib.SMA(self.prices.values, 108)
            self.sma1 = talib.SMA(self.prices.values, 694)
            self.logger.debug('sma是')
            self.logger.debug(str(self.sma[-3:]))
            self.logger.debug(str(self.sma1[-3:]))
            have = abs(self.getBroker().getShares(self.transInstrument(self.__instrument)))
            

            if self.sma[-2] > self.sma1[-2] and self.sma[-3] < self.sma1[-3]:
                if have != 0:
                    # ret = self.getBroker().createLimitOrder(broker.Order.Action.BUY_TO_COVER,
                    #                                          self.transInstrument(self.__instrument), have)
                    # 改用算法单
                    ret = self.getBroker().create_algo_order(broker.Order.Action.BUY_TO_COVER,
                                                            self.transInstrument(self.__instrument), have)
                    self.getBroker().submitOrder(ret)
                    self.logger.critical('平仓1')
                #ret = self.getBroker().createLimitOrder(broker.Order.Action.BUY,
                #                                         self.transInstrument(self.__instrument), quantity)
                #self.getBroker().submitOrder(ret)
                    self.logger.critical('买1 ' + str(bars.getDateTime()))
                    self.unsend_order = (broker.Order.Action.BUY, self.transInstrument(self.__instrument), quantity)
                else:
                    self.logger.critical('买1 ' + str(bars.getDateTime()))
                    ret = self.getBroker().create_algo_order(broker.Order.Action.BUY, self.transInstrument(self.__instrument), quantity)
                    self.getBroker().submitOrder(ret)

            elif self.sma[-2] < self.sma1[-2] and self.sma[-3] > self.sma1[-3]:
                if have != 0:
                    # ret = self.getBroker().createLimitOrder(broker.Order.Action.SELL,
                    #                                          self.transInstrument(self.__instrument), have)
                    # 改用算法单
                    ret = self.getBroker().create_algo_order(broker.Order.Action.SELL,
                                                            self.transInstrument(self.__instrument), have)
                    self.getBroker().submitOrder(ret)
                    self.logger.critical('平仓2')
                #ret = self.getBroker().createLimitOrder(broker.Order.Action.SELL_SHORT,
                #                                         self.transInstrument(self.__instrument), quantity)
                #self.getBroker().submitOrder(ret)
                    self.logger.critical('卖1  ' + str(bars.getDateTime()))
                    self.unsend_order = (broker.Order.Action.SELL_SHORT, self.transInstrument(self.__instrument), quantity)
                else:
                    self.logger.critical('卖1  ' + str(bars.getDateTime()))
                    ret = self.getBroker().create_algo_order(broker.Order.Action.SELL_SHORT, self.transInstrument(self.__instrument), quantity)
                    self.getBroker().submitOrder(ret)

        else:
            if cross.cross_above(self.sma, self.sma1) > 0:
                if self.getBroker().getShares(self.__instrument) != 0:
                    ret = self.getBroker().createMarketOrder(Action.BUY_TO_COVER,
                                                             self.transInstrument(self.__instrument), quantity)
                    self.getBroker().submitOrder(ret)
                    self.logger.debug('平仓3')
                ret = self.getBroker().createMarketOrder(Action.BUY,
                                                         self.transInstrument(self.__instrument), quantity)
                self.getBroker().submitOrder(ret)
                self.logger.debug('买2' + str(bars.getDateTime()))
            elif cross.cross_below(self.sma, self.sma1) > 0:
                if self.getBroker().getShares(self.__instrument) != 0:
                    ret = self.getBroker().createMarketOrder(Action.SELL,
                                                             self.transInstrument(self.__instrument), quantity)
                    self.getBroker().submitOrder(ret)
                self.logger.debug('平仓4')
                ret = self.getBroker().createMarketOrder(Action.SELL_SHORT,
                                                         self.transInstrument(self.__instrument), quantity)
                self.getBroker().submitOrder(ret)
                self.logger.debug('卖2' + str(bars.getDateTime()))

    def on_order_finished(self, order):
        if self.unsend_order:
            # ret = self.getBroker().createLimitOrder(*self.unsend_order) 改用算法单
            ret = self.getBroker().create_algo_order(*self.unsend_order)
            self.getBroker().submitOrder(ret)
            self.logger.critical(f'开仓，方向、标的、数量是：{self.unsend_order[0]},{self.unsend_order[1]},{self.unsend_order[2]}')
            self.unsend_order = []
        
        self.logger.debug(f'SMACrossOver order finished filled price:{order.filledPrice}')

class TurtleTrade(YhlzStreategy):
    """
    海龟交易策略
    """

    def __init__(self, feed, instruments, context, dictOfDataDf, atrPeriod=14, short=20, long=55):
        """
        初始化
        :parm feed pyalgotrade 的feed对象，装了所有csv数据。类似于dict可以用中括号取值。
        :parm instrument 包含所有category的list，用的是简写，如‘rb’，‘ag’
        :param context context 对象，装所有变量
        :parm atrPeriod atr的周期
        :parm short 唐奇安通道的短期
        :parm long 唐奇安通道的长期
        :parm dictOfDataDf 包含所有数据的dict，其中每一个category是一个df
        """
        self.initialCash = 100000
        super(TurtleTrade, self).__init__(feed, self.initialCash)
        self.feed = feed
        if isinstance(instruments, list):  # 对于不是多个品种的情况，进行判断，如果是字符串，用list包裹在存储
            self.instruments = instruments
        else:
            self.instruments = [instruments]
        self.atrPeriod = atrPeriod
        self.short = short  # * 300  # 测试
        self.long = long  # * 300  # 测试
        self.dictOfDateDf = dictOfDataDf
        self.context = context
        self.generalTickInfo = pd.read_csv(context.general_ticker_info_path)
        if self.realTrade:
            pass
        else:
            self.tech = {}
            # self.max = 0  # 用来存储最大的历史数据有多长， 以计算此时到了哪一根k线，以方便用 -(self.max - self.i) 来取技术指标
            for instrument in self.instruments:
                atr = talib.ATR(np.array(self.dictOfDateDf[instrument]['High'], dtype=np.float),
                                np.array(self.dictOfDateDf[instrument]['Low'], dtype=np.float),
                                np.array(self.dictOfDateDf[instrument]['Close'], dtype=np.float),
                                self.atrPeriod)  # 返回的ndarray
                long_upper = talib.MAX(np.array(self.dictOfDateDf[instrument]['High'], dtype=np.float), self.long)
                long_lower = talib.MIN(np.array(self.dictOfDateDf[instrument]['Low'], dtype=np.float), self.long)
                short_upper = talib.MAX(np.array(self.dictOfDateDf[instrument]['High'], dtype=np.float), self.short)
                short_lower = talib.MIN(np.array(self.dictOfDateDf[instrument]['Low'], dtype=np.float), self.short)

                self.tech[instrument] = {'atr': atr, 'long upper': long_upper, 'long lower': long_lower,
                                         'short upper': short_upper, 'short lower': short_lower}
                # if len(atr) > self.max:
                #     self.max = len(atr)

            # self.i = 0  # 计数，用于确定计数指标的位置

        self.openPriceAndATR = {}  # 用于记录每个品种的开仓价格与当时的atr
        self.equity = 0

    def onBars(self, bars):
        barTime = bars.getDateTime()
        self.equity = self.getBroker().getEquity()
        if self.equity < 0:
            return
        order = []
        allAtr = {}
        postion = self.getBroker().getPositions()
        readyInstrument = bars.getInstruments()
        for instrument in self.instruments:
            # t1 = time.time()
            # atr = talib.ATR(np.array(self.feed[instrument].getHighDataSeries()),
            #                 np.array(self.feed[instrument].getLowDataSeries()),
            #                 np.array(self.feed[instrument].getCloseDataSeries()), self.atrPeriod)[-1]  # 返回的ndarray
            # if np.isnan(atr):  # 为nan说明数据还不够，不做计算。
            #     continue
            # allAtr[instrument] = atr
            # quantity = self.getQuantity(instrument, atr)
            # long_upper = talib.MAX(np.array(self.feed[instrument].getHighDataSeries()), self.long)
            # long_lower = talib.MIN(np.array(self.feed[instrument].getLowDataSeries()), self.long)
            # short_upper = talib.MAX(np.array(self.feed[instrument].getHighDataSeries()), self.short)
            # short_lower = talib.MIN(np.array(self.feed[instrument].getLowDataSeries()), self.short)

            if instrument not in readyInstrument:  # 如果此时没有这个品种的bar 说明还没开始或者别的品种的夜盘和它时间冲突
                temp = self.dictOfDateDf[instrument][self.dictOfDateDf[instrument]['Date Time'] < barTime]
                if temp.index.empty:  # 如果index是空值，则说明此时这个品种还没有开始有数据
                    i = 0
                else:
                    i = temp.index[-1]  # 有数据，就说明是中间有夜盘时间不对齐的问题，用前一天的atr来代替

                atr = self.tech[instrument]['atr'][i]
                allAtr[instrument] = atr
                # 对于这种取后一天的atr来假装，避免后面取atr 之前开仓所以有持仓，但此时没有atr的报错
                continue
            i = self.dictOfDateDf[instrument][self.dictOfDateDf[instrument]['Date Time'] == barTime].index[0]
            atr = self.tech[instrument]['atr'][i]  # * 50#测试
            allAtr[instrument] = atr

            if np.isnan(atr):  # 为nan说明数据还不够，不做计算。
                continue

            #  找到这个时间在df中的位置
            quantity = self.getQuantity(instrument, atr)
            long_upper = self.tech[instrument]['long upper'][i - 1:i + 1]  # 取出到此时的最后两个
            long_lower = self.tech[instrument]['long lower'][i - 1:i + 1]
            short_lower = self.tech[instrument]['short lower'][i - 1:i + 1]
            short_upper = self.tech[instrument]['short upper'][i - 1:i + 1]
            # t2 = time.time()
            # print(t2 - t1)
            # 开仓。
            if long_upper[-1] > long_upper[-2] and postion.get(instrument, 0) == 0:  # 当期上界变高表示创新高，新低同理
                ret = self.getBroker().createMarketOrder(Action.BUY, instrument, quantity)
                self.openPriceAndATR[instrument] = [bars.getBar(instrument).getClose(), atr]  # 默认以收盘价开仓
                print('open long')
                print(bars.getDateTime())
                print(instrument)
                print(postion)
                orders.append(ret)

            elif long_lower[-1] < long_lower[-2] and postion.get(instrument, 0) == 0:
                ret = self.getBroker().createMarketOrder(Action.SELL_SHORT, instrument, quantity)
                self.openPriceAndATR[instrument] = [bars.getBar(instrument).getClose(), atr]  # 默认以收盘价开仓
                print('open short')
                print(bars.getDateTime())
                print(instrument)
                print(postion)
                orders.append(ret)
            # 平仓
            elif short_upper[-1] > short_upper[-2] and postion.get(instrument, 0) < 0:  # 平空
                ret = self.getBroker().createMarketOrder(Action.BUY_TO_COVER, instrument,
                                                         abs(postion[instrument]))
                print('close short')
                print(bars.getDateTime())
                print(instrument)
                print(postion)
                orders.append(ret)
                self.openPriceAndATR.pop(instrument)  # 去掉指定的持仓


            elif short_lower[-1] < short_lower[-2] and postion.get(instrument, 0) > 0:  # 平多
                ret = self.getBroker().createMarketOrder(Action.SELL, instrument,
                                                         abs(postion[instrument]))
                print('close long')
                print(bars.getDateTime())
                print(instrument)
                print(postion)
                orders.append(ret)
                self.openPriceAndATR.pop(instrument)  # 去掉指定的持仓

            # 加仓 或止损
            elif instrument in self.openPriceAndATR:  # 表示已有持仓
                if postion.get(instrument, 0) > 0:  # 持有多仓
                    if self.openPriceAndATR[instrument][0] + 0.5 * atr < bars.getBar(instrument).getClose():
                        # 且价格超过了上次开仓价加0.5atr在加仓
                        ret = self.getBroker().createMarketOrder(Action.BUY, instrument,
                                                                 quantity)
                        self.openPriceAndATR[instrument][0] = bars.getBar(instrument).getClose()
                        print('add long')
                        print(bars.getDateTime())
                        print(instrument)
                        print(postion)
                        orders.append(ret)

                    elif self.openPriceAndATR[instrument][0] - 0.5 * atr > bars.getBar(instrument).getClose():
                        # 且价格小于了上次开仓价减0.5atr，则止损
                        ret = self.getBroker().createMarketOrder(Action.SELL, instrument,
                                                                 abs(postion[instrument]))
                        self.openPriceAndATR.pop(instrument)
                        print('stop long')
                        print(bars.getDateTime())
                        print(instrument)
                        print(postion)
                        orders.append(ret)

                elif postion.get(instrument, 0) < 0:  # 持有空仓
                    if self.openPriceAndATR[instrument][0] - 0.5 * atr > bars.getBar(instrument).getClose():
                        # 且价格超过了开仓价加0.5atr 则加空
                        ret = self.getBroker().createMarketOrder(Action.SELL_SHORT, instrument,
                                                                 quantity)
                        self.openPriceAndATR[instrument][0] = bars.getBar(instrument).getClose()
                        print('add short')
                        print(bars.getDateTime())
                        print(instrument)
                        print(postion)
                        orders.append(ret)

                    elif self.openPriceAndATR[instrument][0] + 0.5 * atr < bars.getBar(instrument).getClose():
                        # 且价格大于了上次开仓价加0.5atr，则止损
                        ret = self.getBroker().createMarketOrder(Action.BUY_TO_COVER, instrument,
                                                                 abs(postion[instrument]))
                        self.openPriceAndATR.pop(instrument)
                        print('stop short')
                        print(bars.getDateTime())
                        print(instrument)
                        print(postion)
                        orders.append(ret)
            # t3 = time.time()
            # print(t3 - t2)

        # t3 = time.time()
        allPos = 0
        for instrument in postion:
            allPos += round(postion[instrument] / self.getQuantity(instrument, allAtr[instrument]))
            # 看某个品种有多少个单位的持仓，按照现在的atr来计算
        allPos = abs(allPos)
        if allPos >= 10:
            open_mark = False  # 达到10个单位，不再开仓
        else:
            open_mark = True

        for item in orders:
            item.setGoodTillCanceled(True)
            item.setAllOrNone(True)
            action = item.getAction()
            if action == Action.SELL or action == Action.BUY_TO_COVER:  # 平仓的都可以
                self.getBroker().submitOrder(item)
            else:  # 开仓的情况
                if open_mark:
                    ins = item.getInstrument()
                    exist = round(postion.get(ins, 0) / self.getQuantity(ins, allAtr[ins]))
                    if exist <= 3:  # 如果单个品种小于3个单位的持仓，就可以开
                        self.getBroker().submitOrder(item)
                        allPos += 1
                    if allPos >= 10:
                        open_mark = False
        # self.i += 1  # 自增以移向下一个计数指标的值
        t4 = time.time()
        # print(t4 - t3)

    def getQuantity(self, instrument, atr):
        """
        计算此时可以开多少张
        :return:
        """

        quantity = self.equity
        # quantity = self.initialCash  # 测试, 固定资产开仓，不随资产增长，避免回测到后期钱太多的问题

        temp = instrument.split('.')
        KQFileName = 'KQi@' + temp[0] + temp[1].strip(string.digits)

        temp = self.generalTickInfo.loc[self.generalTickInfo['index_name'] == KQFileName, 'contract_multiplier']
        if temp.empty:
            # 检查获取到的series是否为空再去iloc避免出现未知的异常。
            raise Exception(f'generalTickInfo里面没找到合约:{KQFileName},请检查！')
        KQmultiplier = temp.iloc[0]
        res = int(quantity / atr / 100 / KQmultiplier)  # 向下取整
        # res = int(quantity / atr / 100)  # 由于目前回测系统没有考虑合约乘数，不需要除以合约乘数

        if res:
            return res
        else:
            return 0  # 至少开1手

        # 账户的1%的权益，除去atr值，再除去合约乘数，即得张数。表示一个atr的标准波动让账户的权益变动1%


class RealTurtleTrade(TurtleTrade):
    """用于实盘的海龟交易法"""

    def __init__(self, feed, instruments, *args):
        # global THREAD_BUFFER
        self.logger = logging.getLogger(f'Yhlz.strategy.{self.__class__.__name__}')
        super(RealTurtleTrade, self).__init__(feed, instruments, *args)
        self.broker = self.getBroker()  # 减少每次获取broker的时间开销
        self.feed = feed
        # context 对象包含了tushare的句柄。以及tick对象
        self.context = args[0]
        self.unSendOrder = []  # 记录未发送的订单
        self.unfinished_order = {}

        # 读取json文件记录的关于这个策略的所有信息
        with open(self.context.Record_path) as f:
            self.allInfo = json.load(f)['RealTurtleTrade']
        self.logger.debug(f'策略所有信息是 :\n{self.allInfo}')

        self.PriceVolumeAtr = self.allInfo['PriceVolumeAtr']
        # contractCorrelation记录了一个数据库路径，里面有所有合约的关系
        
        if 'Linux' not in platform.platform():
            conn = sqlite3.connect(self.allInfo['contractCorrelation'][0])
        else:
            conn = sqlite3.connect('/'.join(self.allInfo['contractCorrelation'][0].split('\\')))
        sql = f"""SELECT * FROM {self.allInfo['contractCorrelation'][1]}"""
        self.contractCorrelation = pd.read_sql(sql, conn)
        conn.close()

        # 获取所有合约的合约基本信息。
        self.contractInfo = self.context.contractInfo

        now = datetime.now()
        # ts 接口 通过给定交易所，获取交易日历
        self.tradeDate = self.context.tsObj.trade_cal(exchange='SHFE', start_date=now.strftime('%Y%m%d'),
                                                      end_date=(now + timedelta(days=400)).strftime('%Y%m%d'))

        self.logger.debug(f'最终合约信息是：\n{self.contractInfo}')
        self.logger.debug(f'交易日历是：\n{self.tradeDate}')

        # 计算合约更换标注，第一次ontick计算，之后不计算。 在初始化的时候
        self.ChangeContractMark = True
        self.exitMark = False  # 控制状态检查线程是否退出。
        # 计算atr， 本来是准备实时计算atr，但是这样会带来很大的计算量，于是决定一天计算一次即可，calAtr只在计算好的dict里面获取数据
        self.allAtr = {}
        for symbol in self.feed:
            # feed里面的symbol默认是带1d所以需要去掉
            self.allAtr[symbol.replace('1d', '')] = talib.ATR(np.array(self.feed[symbol].getHighDataSeries()),
                                                              np.array(self.feed[symbol].getLowDataSeries()),
                                                              np.array(self.feed[symbol].getCloseDataSeries()),
                                                              self.atrPeriod)[-1]
        equity = self.broker.accountInfo.balance
        for key in self.allAtr:  # 避免有些没有k线的合约出现atr为0的情况,将atr赋值为当前权益，则计算出来的张数必定为0，类似于atr无穷大的意思
            if self.allAtr[key] == 0:
                self.allAtr[key] = equity
        self.logger.debug(f'allAtr是：{self.allAtr}')

        # 计算唐奇安通道
        self.TQAChannel = {}
        for symbol in self.feed:
            if '1d' not in symbol:
                continue  # 只计算日线频率，其他频率本策略不处理。
            symbol = symbol.replace('1d', '')
            if symbol not in self.context.allTick:  
                # tick和k线都有可能订阅失败，没有订阅到tick的不会触发ontick不需要计算唐奇安通道
                continue

            self.TQAChannel[symbol] = {}
            if self.broker.inTradeTime(self.context.allTick[symbol]):  # 开盘的时候需要不计算最新的。
                self.TQAChannel[symbol]['long_upper'] = max(self.feed[symbol + '1d'].getHighDataSeries()[-56: -1])
                self.TQAChannel[symbol]['long_lower'] = min(self.feed[symbol + '1d'].getLowDataSeries()[-56: -1])
                self.TQAChannel[symbol]['short_upper'] = max(self.feed[symbol + '1d'].getHighDataSeries()[-21: -1])
                self.TQAChannel[symbol]['short_lower'] = min(self.feed[symbol + '1d'].getLowDataSeries()[-21: -1])

            else:  # 未开盘时k线需要计算最新值
                self.TQAChannel[symbol]['long_upper'] = max(self.feed[symbol + '1d'].getHighDataSeries()[-55:])
                self.TQAChannel[symbol]['long_lower'] = min(self.feed[symbol + '1d'].getLowDataSeries()[-55:])
                self.TQAChannel[symbol]['short_upper'] = max(self.feed[symbol + '1d'].getHighDataSeries()[-20:])
                self.TQAChannel[symbol]['short_lower'] = min(self.feed[symbol + '1d'].getLowDataSeries()[-20:])
        self.logger.debug(f'唐奇安通道是：{self.TQAChannel}')

        self.openMark = {}  # 用于记录这个标的当前是否可以开仓
        self.checkOpenMark = True  # 用于判断是否需要计算各合约是否可以开仓。


        self.logger.debug('初始化完成！')

    def checkFilled(self):
        """用于将订单成交后的成交价计算然后得到止损价记录"""
        # global THREAD_BUFFER  # 将所有委托的成交信息都写入到这个缓存内交互给原主线程。
        # self.logger.debug(f'线程缓冲变量初始内容：\n{THREAD_BUFFER}')
        # while True:
        #     time.sleep(1)
        if self.exitMark:
            self.logger.info('收到退出信号，退出检查线程。')
            # break  # 退出标致置True就退出。
            return
        self.logger.debug(f"check filled's order{self.broker.getOrders()}")
        for order in self.broker.getOrders()[:]:
            if order.is_dead:
                self.broker.getOrders().remove(order)
                atr = self.calAtr(order.contract)
                # price = order.filledPrice + atr if order.direction == 'BUY' \
                #     else order.filledPrice - atr
                if not order.open:
                    continue
                if order.volumeLeft != 0:  # 部分成交或者未成交被撤单了
                    if not self.unfinished_order[order.contract]:
                        self.unfinished_order[order.contract] = []
                    self.unfinished_order[order.contract].append([order.filledVolume, order.filledPrice])

                if order.volumeLeft == 0 and not hasattr(order, 'change_contract'):
                    # 只append处理非移仓的单子，移仓的单子直接修改pricevolumeatr，避免checkfilled的逻辑过于复杂
                    if self.unfinished_order[order.contract]:
                        amount = 0
                        volume = 0
                        for row in self.unfinished_order[order.contract]:
                            amount += row[0] * row[1]
                            volume += row[0]
                        price = float(amount/volume)
                    else:
                        price = order.filledPrice
                        volume = order.filledVolume

                    if order.contract not in self.PriceVolumeAtr:
                        self.PriceVolumeAtr[order.contract] = []
                    self.PriceVolumeAtr[order.contract].append({'price': price, 'volume': volume,
                                                          'atr': atr, 'direction': order.direction})
                    self.unfinished_order[order.contract] = []

                with open(self.context.Record_path, 'r') as f:
                    temp = json.load(f)

                with open(self.context.Record_path, 'w') as f:
                    temp['RealTurtleTrade']['PriceVolumeAtr'] = self.PriceVolumeAtr
                    json.dump(temp, f)
                self.logger.debug(f'dead的order是：{order}')
                self.logger.debug(f'：PriceVolumeAtr是：{self.PriceVolumeAtr}')

    def onTickTest(self, ticks):
        if datetime.now().hour == 20:  # 20:59会有一个tick推送，但是此时是不能下单的，所以要跳过。
            return
        # 获取文件记录的当前权益。
        self.position = self.broker.getPositions('RealTurtleTrade')
        # self.logger.debug(f'计算完换合约时的持仓情况是：\n{self.position}')
        self.equity = self.calAvaEquity()
        if self.checkOpenMark:  # 不能在初始化中获取资金，于是在第一次ontick的时候进行计算是否该合约可以开仓
            self.checkOpenMark = False
            for symbol in self.feed:
                if '1d' not in symbol:
                    continue
                symbol = symbol.replace('1d', '')
                self.openMark[symbol] = self.initCheckOpen(symbol)
        self.checkFilled()
        self.unSendOrder = []
        for tick in ticks:
            self.logger.debug(f'合约：{tick.instrument_id}， tick：{tick}----------------------------------------------')
            symbol = tick.instrument_id

            haveMark = self.position.contract.isin([tick.instrument_id]).any()
            if symbol not in self.TQAChannel:  
                # 如果没有订阅全部合约的k线，那么多策略运行别的策略订阅了的tick，这个策略的k线可能会没有会报错
                continue

            long_upper = self.TQAChannel[symbol]['long_upper']
            long_lower = self.TQAChannel[symbol]['long_lower']

            # 如果order中还有未结的订单，直接跳过，不做逻辑判断
            for order in self.broker.getOrders():
                if order.virContract == symbol:
                    self.logger.info(f'合约有未结单，跳过tick判断！')
                    continue

            self.logger.info(f'have是:{haveMark}')
            if not haveMark:
                if tick.last_price > long_upper:
                    if self.handleOnTick(tick):
                        ret = self.broker.createLimitOrder(Action.BUY, symbol, self.N)
                        self.logger.debug('open long')
                        self.unSendOrder.append(ret)
                elif tick.last_price < long_lower:
                    if self.handleOnTick(tick):
                        ret = self.broker.createLimitOrder(Action.SELL_SHORT, symbol, self.N)
                        self.logger.debug('open short')
                        self.unSendOrder.append(ret)
                else:
                    continue  # 没有持仓不符合开仓条件的直接返回。避免计算多耗费时间。正常情况绝大多数合约都应该在这里返回
            elif haveMark:
                openMark = self.handleOnTick(tick)
                have = self.broker.getShares(symbol)
                short_upper = self.TQAChannel[symbol]['short_upper']
                short_lower = self.TQAChannel[symbol]['short_lower']
                # 平仓
                if tick.last_price > short_upper and have < 0:  # 平空
                    ret = self.broker.createLimitOrder(Action.BUY_TO_COVER, symbol,
                                                        have)
                    self.logger.debug('close short')

                    self.unSendOrder.append(ret)
                    self.PriceVolumeAtr.pop(symbol)  # 去掉指定的持仓
                    continue  # 避免pop之后，后面判断出问题

                elif tick.last_price < short_lower and have > 0:  # 平多
                    ret = self.broker.createLimitOrder(Action.SELL, symbol,
                                                        have)
                    self.logger.debug('close long')

                    self.unSendOrder.append(ret)
                    self.PriceVolumeAtr.pop(symbol)  # 去掉指定的持仓
                    continue  # 避免pop之后，后面判断出问题

                if self.position.loc[self.position['contract'] == symbol, 'direction'].iloc[0] == 'BUY':  # 持有多仓
                    addPrice = 0
                    self.logger.debug('持有多仓')
                    for i, item in enumerate(self.PriceVolumeAtr[symbol][:]):
                        if self.PriceVolumeAtr[symbol][i]['price'] + 0.5 * self.PriceVolumeAtr[symbol][i]['atr'] > addPrice:
                            addPrice = self.PriceVolumeAtr[symbol][i]['price'] + 0.5 * self.PriceVolumeAtr[symbol][i]['atr']
                            self.logger.debug(f'add price:{addPrice}')


                    if addPrice < tick.last_price and addPrice != 0:
                        # 当出现某些异常，持仓记录为空，currentAccount不为空，导致cutprice为无穷然后加仓的情况
                        # 且价格超过了上次开仓价加0.5atr再加仓
                        if openMark:
                            ret = self.broker.createLimitOrder(Action.BUY, symbol, self.N)
                            self.logger.debug(f'add long2 add price:{addPrice}, pricevolumeatr:{self.PriceVolumeAtr}')
                            self.unSendOrder.append(ret)
                    
                    for i, item in enumerate(self.PriceVolumeAtr[symbol][:]):
                        if item['price'] - 0.5 * item['atr'] > tick.last_price:
                            # 且价格小于了上次开仓价减0.5atr，则止损
                            ret = self.broker.createLimitOrder(Action.SELL, symbol, item['volume'])
                            self.PriceVolumeAtr[symbol].pop()  # 先去掉一个委托
                            self.logger.debug(f'stop long1, add price:{addPrice}, pricevolumeatr:{self.PriceVolumeAtr}')

                            self.unSendOrder.append(ret)
                            if not self.PriceVolumeAtr[symbol]:  # 如果去掉后，该品种持仓记录变成了空列表，则直接去掉这个空列表
                                self.PriceVolumeAtr.pop(symbol)
                                break

                elif self.position.loc[self.position['contract'] == symbol, 'direction'].iloc[0] == 'SELL':  # 持有空仓
                    cutPrice = float("inf")
                    self.logger.debug('持有空仓')
                    for i, item in enumerate(self.PriceVolumeAtr[symbol][:]):
                        if self.PriceVolumeAtr[symbol][i]['price'] - 0.5 * self.PriceVolumeAtr[symbol][i]['atr'] < cutPrice:
                            cutPrice = self.PriceVolumeAtr[symbol][i]['price'] - 0.5 * self.PriceVolumeAtr[symbol][i]['atr']
                            self.logger.debug(f'cut price:{cutPrice}')


                    if tick.last_price < cutPrice and not math.isinf(cutPrice):
                        # 当出现某些异常，持仓记录为空，currentAccount不为空，导致cutprice为无穷然后加仓的情况
                        # 且价格超过了开仓价加0.5atr 则加空
                        if openMark:
                            ret = self.broker.createLimitOrder(Action.SELL_SHORT, symbol, self.N)
                            self.logger.debug(f'add short4, pricevolumeatr:{self.PriceVolumeAtr}')
                            self.unSendOrder.append(ret)

                    for i, item in enumerate(self.PriceVolumeAtr[symbol][:]):
                        if item['price'] + 0.5 * item['atr'] < tick.last_price:
                            # 且价格大于了上次开仓价加0.5atr，则止损
                            ret = self.broker.createLimitOrder(Action.BUY_TO_COVER, symbol, item['volume'])
                            self.PriceVolumeAtr[symbol].pop()  # 先去掉一个委托
                            self.logger.debug(f'stop short3, cut price:{cutPrice}, pricevolumeatr:{self.PriceVolumeAtr}')
                            self.unSendOrder.append(ret)
                            if not self.PriceVolumeAtr[symbol]:  # 如果去掉后，该品种持仓记录变成了空列表，则直接去掉这个空列表
                                self.PriceVolumeAtr.pop(symbol)
                                break

        # 处理换合約
        if self.ChangeContractMark:  # 只在每次开盘的第一个tick计算，后面的就不计算了。
            self.logger.info('要检查换合约')
            self.calChangeContract()
            self.ChangeContractMark = False
            self.logger.debug('检查换合约完成！')

            # 更改record记录中的原合约变为新的主力合约
            def change_name(contract, main_contract):
                new_contract = []
                for item in self.PriceVolumeAtr[contract]:
                    temp = {}
                    temp['volume'] = item['volume']
                    temp['direction'] = item['direction']
                    temp['atr'] = self.allAtr[main_contract]
                    temp['price'] =  self.context.allTick[main_contract].last_price - \
                                     (self.context.allTick[contract].last_price - item['price']) / item['atr'] * temp['atr']
                    new_contract.append(temp)
                self.PriceVolumeAtr[main_contract] = new_contract
                self.PriceVolumeAtr.pop(contract)

            for i in range(len(self.position)):
                symbol = self.position.loc[i, 'contract']
                if type(symbol) != str:
                    continue
                if not self.position.loc[self.position['contract'] == symbol, 'changeFlag'].empty and \
                        self.position.loc[self.position['contract'] == symbol, 'changeFlag'].iloc[0] == 1:
                    # 有查找这个合约不是空，（有持仓）并且还换合约的标志位为1就进行换合约。
                    main_symbol = self.getMainContract(symbol)
                    self.logger.debug(f'主力合约是：{main_symbol}')
                    if self.checkOpen(main_symbol):  # 主力合約是否可以開倉
                        self.N = self.getQuantity(symbol, self.calAtr(symbol))
                        if self.position.loc[self.position['contract'] == symbol, 'direction'].iloc[0] == 'BUY':
                            self.logger.debug(f'移多仓')
                            volume = self.position.loc[self.position['contract'] == symbol, 'volume'].iloc[0]
                            ret = self.broker.createLimitOrder(Action.SELL, symbol, volume)
                            ret.change_contract = True
                            self.unSendOrder.append(ret)
                            self.logger.debug(f'{ret}')
                            # 有几条记录就说明要开多少个N不管过去是多少手，只按现在的标准单位来
                            ret = self.broker.createLimitOrder(Action.BUY, main_symbol, volume)

                            ret.change_contract = True
                            self.unSendOrder.append(ret)
                            self.logger.debug(f'{ret}')
                            change_name(symbol, main_symbol)

                        elif self.position.loc[self.position['contract'] == symbol, 'direction'].iloc[0] == 'SELL':
                            self.logger.debug(f'移空仓')
                            volume = self.position.loc[self.position['contract'] == symbol, 'volume'].iloc[0]
                            ret = self.broker.createLimitOrder(Action.BUY_TO_COVER, symbol, volume)
                            ret.change_contract = True
                            self.unSendOrder.append(ret)
                            self.logger.debug(f'{ret}')
                            ret = self.broker.createLimitOrder(Action.SELL_SHORT, main_symbol, volume)
                            ret.change_contract = True
                            self.unSendOrder.append(ret)
                            self.logger.debug(f'{ret}')
                            change_name(symbol, main_symbol)
                    else:
                        assert False, f'主力合约不可以开仓! 请检查并处理: {main_symbol}'
        # 报单
        if self.unSendOrder:
            self.logger.debug('报单！')
        for order in self.unSendOrder:
            self.logger.critical(f'{order}')
            self.broker.submitOrder(order)


    def handleOnTick(self, tick):
        """优化ontick计算时间的需要将这些是否可以交易的判断合起来，因为需要在多个if else分支中判断"""

        # self.logger.debug(f'1.2{time.time()}')
        orders = self.broker.getOrders()
        if orders:
            for order in orders[:]:
                # 判断是否有一个品种的订单，一个品种只要有一个合约有持仓或订单，那么这个合约就不能再开仓了。
                if order.contract.split('.')[1].strip(string.digits) == tick.product_id:
                    self.logger.debug(f'{tick.underlying_symbol}~{order.contract} 还有同品种订单未完！')
                    return False
                # if order.contract == tick.underlying_symbol:  # 如果有未成交的单只阻塞对应合约的tick判断。
                #     self.logger.debug(f'{tick.underlying_symbol} 还有本合约订单未处理完，直接返回！')
                #     return
        for order in self.unSendOrder:
            # 检查是否有未发的同品种订单，避免同时发出多个单
            if order.contract.split('.')[1].strip(string.digits) == tick.product_id:
                self.logger.debug(f'{tick.underlying_symbol}~{order.contract} 还有同品种订单未完！')
                return False

        # self.logger.debug(f'2{time.time()}')
        # self.logger.debug(f'3{time.time()}')
        self.logger.debug(f'获取到的成交价量内容是：{self.PriceVolumeAtr}')
        symbol = tick.instrument_id
        self.logger.debug(f'标的是：{symbol}')
        # self.logger.debug(f'4{time.time()}')
        # 检查是否可以开仓，不能直接返回。
        if not self.checkOpen(symbol):
            self.logger.debug(f'合约不能开仓直接退出！')
            return False
        # self.logger.debug(f'5{time.time()}')
        # 判断合约相关性。限制开仓
        # 在record 文件里面的本策略下有一个叫contractCorrelation的字段 内容是{合约：{合约："相关性强度"}} 来判断市场的关联度。
        self_ = 0  # 同合约
        close = 0  # 高度关联计数
        loose = 0  # 松散关联
        long = 0  # 同方向
        short = 0
        all_ = 0  # 所有的
        indexNow = 'KQ.i@' + symbol.strip(string.digits)
        for contract in self.PriceVolumeAtr:
            count = len(self.PriceVolumeAtr[contract])
            code = symbol.split('.')[1].strip(string.digits).upper()
            loose_temp = self.contractCorrelation.loc[self.contractCorrelation['code'] == code, 'fen_lei'].iloc[0]
            close_temp = self.contractCorrelation.loc[self.contractCorrelation['code'] == code, 'xi_fen_lei'].iloc[0]
            # 对于每个合约记录的是一个list的dict因为一个合约可能有多次记录。
            # 根据记录的长度来增加
            if 'KQ.i@' + contract.strip(string.digits) == indexNow:
                if contract != tick.instrument_id:
                    # 对于已经有持仓的品种，如果不是同一个合约，不在持仓，
                    # 例如 rb2110持仓了，那么对于rb其他的合约直接跳过不计算。
                    self.logger.debug(f'{contract}::{tick.instrument_id} 同品种直接跳过。')
                    return False
                self_ += count  # 同品种
            if contract.strip(string.digits) in \
                self.contractCorrelation.loc[self.contractCorrelation['fen_lei'] == loose_temp, 'code'].to_list():
                loose += count
            if contract.strip(string.digits) in \
                    self.contractCorrelation.loc[self.contractCorrelation['xi_fen_lei'] == close_temp, 'code'].to_list():
                close += count

            # if 'KQ.i@' + contract.strip(string.digits) in self.contractCorrelation.get(indexNow, {}).get('close',
            #                                                                                              {}):  # 如果没有记录这个合约就返回空dict避免报错。
            #     close += count  # 紧密关联
            # if 'KQ.i@' + contract.strip(string.digits) in self.contractCorrelation.get(indexNow, {}).get('relax', {}):
            #     loose += count  # 松散关联
            for pos in self.PriceVolumeAtr[contract]:
                if pos['direction'] == 'BUY':
                    long += 1  # 多头
                if pos['direction'] == 'SELL':
                    short += 1  # 空头
            all_ += count  # 每一次所有的都加1

        # self.logger.debug(f'6{time.time()}')
        self.logger.debug(f'同合约：{self_}, 高度关联：{close}, 松散关联：{loose}, 多头：{long}, 空头：{short}, 全部：{all_}')
        # self.logger.debug(f'6.1{time.time()}')
        if self_ >= 4:  # 同品种持仓大于4个
            return False
        if close + self_ >= 6:  # 紧密关联大于6个
            return False
        if close + self_ + loose >= 10:  # 松散关联大于10个：
            return False
        if long >= 12:
            return False
        if short >= 12:
            return False
        if all_ >= 24:  # 全部
            return False
        # self.logger.debug(f'7{time.time()}')
        self.equity = self.calAvaEquity()
        self.logger.debug(f'可用权益是：{self.equity}')
        # self.logger.debug(f'7.1{time.time()}')
        atr = self.calAtr(symbol)  # 返回的ndarray
        # self.logger.debug(f'7.2{time.time()}')
        self.logger.debug(f'atr是：{atr}')
        # self.logger.debug(f'7.3{time.time()}')
        # 计算标准单位N
        self.N = self.getQuantity(symbol, atr)
        # self.logger.debug(f'7.4{time.time()}')
        self.logger.debug(f'标准单位是：{self.N}')
        if not self.N:
            return False

        return True


    def onBars(self, bars):
        """由于海龟策略需要订阅日线的bar，所以如果不定义onbar就会调用到turtletrade的onbar，出现不兼容情况，重载onbar直接返回"""
        return


    def calChangeContract(self):
        """检查是否需要更换持仓"""
        self.logger.debug(f'检查换合约前持仓：\n{self.position}')
        for i in self.position.index:
            if self.position.loc[i, 'account'] == 1:  # 跳过账户信息行。
                continue
            if not self.checkOpen(self.position.loc[i, 'contract']):  # 持仓中的合约已经不可以再持仓了。
                self.position.loc[i, 'changeFlag'] = 1  # 标识需要移仓
        self.logger.debug(f'检查换合约后持仓：\n{self.position}')


    def getMainContract(self, symbol: str):
        """对于输入的合约输出主力合约，可以下单的, 输入时交易所合约，比如rb2110"""
        self.logger.debug(f'合约是：{symbol}')
        contractName = symbol.strip(string.digits)  # 去掉前面的交易所代码，去掉后面的数字。只要合约标识
        return self.broker.getMainContract(contractName)
        # 去所有的tick里面找主力合约的代码


    def calAvaEquity(self):
        """实盘用来计算可用权益的方法，海龟策略能使用的资金并不一定是在currentAccount里面分配的资金"""
        breakEven = self.allInfo['breakEven']  # 获取保本权益
        equity = self.broker.getEquity('RealTurtleTrade')
        self.logger.debug(f'breakEven:{breakEven}, equity:{equity}')

        # 大于的时候直接返回可用资金
        if breakEven < equity:
            return equity

        # 小于的时候用现在的权益除以保本资金作为百分比，可用资金为权益的这个百分比
        elif breakEven >= equity:
            return equity / breakEven * equity


    def initCheckOpen(self, symbol: str) -> bool:
        """
        由于在每个tick计算是否可以开仓带来了巨大的cpu与
        时间开销所以将这个一次性计算好然后实时的时候只从记录的dict读取
        """

        ts_symbol = KQ2ts(symbol)
        temp = self.contractInfo[self.contractInfo['ts_code'] == ts_symbol]
        assert len(temp) < 2, '查到合约不唯一，请检查！'
        delist_date = temp['delist_date'].iloc[0]
        delist_date = datetime.strptime(delist_date, '%Y%m%d')
        now = datetime.now()
        self.logger.debug(f'symbol:{symbol}, delist_date:{delist_date}, now:{now}')

        # 交割月，12月与1月，或者正好差一个月，才进行判断
        if delist_date.month == now.month:
            self.logger.debug(f'月份相等！')
            if delist_date.year == now.year:
                logger.debug('交割当月，直接退出')
                return False

        if delist_date.month - now.month == 1 or delist_date.month - now.month == -11:
            self.logger.debug(f'距离交割一个月！')
            count = 0
            for i, date in enumerate(self.tradeDate.cal_date):
                if i == 0:  # 当天不计算在内，因为当天不确定是否交易日，也就无法确定到底计算几天
                    self.logger.debug(f'i 为0')
                    continue
                if datetime.strptime(date, '%Y%m%d').month == delist_date.month:  # 找到与交割月同月时，就退出寻找
                    self.logger.debug(f'找到与交割月同月时，就退出寻')
                    return False
                if self.tradeDate.loc[i, 'is_open']:
                    count += 1
                if count >= 2:  # 除开今天还有两天以上，可以持仓。
                    self.logger.debug(f'除开今天还有两天以上，可以持仓。')
                    break
            else:
                assert False, '没有在循环内确定不可开仓，或者继续判断其他条件！请检查。'

        atr = self.calAtr(symbol)
        if np.isnan(atr):
            # atr 如果上市时间不够有可能返回nan这个时候等于技术指标不能进行计算所以直接返回不能开仓。
            return False
        N = self.getQuantity(symbol, atr)
        self.logger.debug(f'N是：{N}, atr是：{atr}')
        temp_symbol = symbol.strip(string.digits)
        multiplier = self.generalTickInfo.loc[self.generalTickInfo['contract_name'] == 'KQ.m@' + temp_symbol, 'contract_multiplier']
        margin_rate = self.generalTickInfo.loc[self.generalTickInfo['contract_name'] == 'KQ.m@' + temp_symbol, 'margin']
        # 倒数第二根日k的成交量与当日开盘时的持仓量都要大于100倍的我要交易的标准单位才可以开仓
        if self.feed[symbol + '1d'].getVolumeDataSeries().iloc[-11:-2].mean() > 100 * N \
                and self.feed[symbol + '1d'].getOpen_oiDataSeries().iloc[-10:].mean() > 100 * N:
            self.logger.debug(f'近10日k的成交量与当日开盘时的持仓量都要大于100倍的我要交易的标准单位可以开仓')
            if self.feed[symbol + '1d'].getVolumeDataSeries().iloc[-2] > 10000 and \
            self.feed[symbol + '1d'].getOpen_oiDataSeries().iloc[-1] > 10000:  # 持仓和成交量都要超过1万张，避免总是交易一些非主力的小合约
                return True
            else:
                self.logger.debug('合约持仓或成交未超过1万张，不开仓！')
                return False
        else:
            self.logger.debug(f'近10天日k的成交量与当日开盘时的持仓量都要大于100倍的我要交易的标准单位*不*可以开仓')
            return False


    def checkOpen(self, symbol: str) -> bool:
        """实时读取开始记录的是否可以开仓的dict"""
        return self.openMark[symbol]


    def calAtr(self, symbol):
        """获取指定品种的atr"""
        self.logger.debug(f'计算atr的symbol是:{symbol}')
        # # feed 内部是用 合约名 +  频率来作为key来存储k线数据的，海龟用的都是日线所以直接加1d
        # # self.logger.debug(f"数据是：\n{self.feed[symbol + '1d']}")  # test  log数据量太大，先注释掉
        # atr = talib.ATR(np.array(self.feed[symbol + '1d'].getHighDataSeries()),
        #             np.array(self.feed[symbol + '1d'].getLowDataSeries()),
        #             np.array(self.feed[symbol + '1d'].getCloseDataSeries()), self.atrPeriod)[-1]
        # return atr
        return self.allAtr[symbol]


    def stop(self):
        """用来给broker调用，停止策略时将数据记入文件"""
        with open(self.context.Record_path, 'r') as f:
            temp = json.load(f)

        with open(self.context.Record_path, 'w') as f:
            temp['RealTurtleTrade']['PriceVolumeAtr'] = self.PriceVolumeAtr
            json.dump(temp, f)
        self.exitMark = True

 
    def on_order_finished(self, order):
        logger.debug(f'RealTurtleTrade order finished filled price:{order.filledPrice}')

class SmaTurtleTrade(YhlzStreategy):
    """
    海龟交易策略
    """

    def __init__(self, feed, instruments, context, dictOfDataDf, atrPeriod=20, short=108, long=694):
        """
        初始化
        :parm feed pyalgotrade 的feed对象，装了所有csv数据。类似于dict可以用中括号取值。
        :parm instrument 包含所有category的list，用的是简写，如‘rb’，‘ag’
        :param context context 对象，装所有变量
        :parm atrPeriod atr的周期
        :parm short 唐奇安通道的短期
        :parm long 唐奇安通道的长期
        :parm dictOfDataDf 包含所有数据的dict，其中每一个category是一个df
        """
        super(SmaTurtleTrade, self).__init__(feed, 10000)

        self.feed = feed
        if isinstance(instruments, list):  # 对于不是多个品种的情况，进行判断，如果是字符串，用list包裹在存储
            self.instruments = instruments
        else:
            self.instruments = [instruments]
        self.atrPeriod = atrPeriod
        self.short = short  # * 300  # 测试
        self.long = long  # * 300  # 测试
        self.dictOfDateDf = dictOfDataDf
        self.context = context
        self.generalTickInfo = pd.read_csv(context.general_ticker_info_path)
        self.openPriceAndATR = {}  # 用于记录每个品种的开仓价格与当时的atr
        self.tech = {}
        # self.max = 0  # 用来存储最大的历史数据有多长， 以计算此时到了哪一根k线，以方便用 -(self.max - self.i) 来取技术指标
        for instrument in self.instruments:
            atr = talib.ATR(np.array(self.dictOfDateDf[instrument]['High']),
                            np.array(self.dictOfDateDf[instrument]['Low']),
                            np.array(self.dictOfDateDf[instrument]['Close']), self.atrPeriod)  # 返回的ndarray
            long = talib.SMA(np.array(self.dictOfDateDf[instrument]['Close']), self.long)
            short = talib.SMA(np.array(self.dictOfDateDf[instrument]['Close']), self.short)

            self.tech[instrument] = {'atr': atr, 'long': long,
                                     'short': short}
            # if len(atr) > self.max:
            #     self.max = len(atr)

        # self.i = 0  # 计数，用于确定计数指标的位置

    def onBars(self, bars):
        barTime = bars.getDateTime()
        self.equity = self.getBroker().getEquity()
        orders = []
        allAtr = {}
        postion = self.getBroker().getPositions()
        readyInstrument = bars.getInstruments()
        for instrument in self.instruments:
            # t1 = time.time()
            # atr = talib.ATR(np.array(self.feed[instrument].getHighDataSeries()),
            #                 np.array(self.feed[instrument].getLowDataSeries()),
            #                 np.array(self.feed[instrument].getCloseDataSeries()), self.atrPeriod)[-1]  # 返回的ndarray
            # if np.isnan(atr):  # 为nan说明数据还不够，不做计算。
            #     continue
            # allAtr[instrument] = atr
            # quantity = self.getQuantity(instrument, atr)
            # long_upper = talib.MAX(np.array(self.feed[instrument].getHighDataSeries()), self.long)
            # long_lower = talib.MIN(np.array(self.feed[instrument].getLowDataSeries()), self.long)
            # short_upper = talib.MAX(np.array(self.feed[instrument].getHighDataSeries()), self.short)
            # short_lower = talib.MIN(np.array(self.feed[instrument].getLowDataSeries()), self.short)

            if instrument not in readyInstrument:  # 如果此时没有这个品种的bar 说明还没开始或者别的品种的夜盘和它时间冲突
                temp = self.dictOfDateDf[instrument][self.dictOfDateDf[instrument]['Date Time'] < barTime]
                if temp.index.empty:  # 如果index是空值，则说明此时这个品种还没有开始有数据
                    i = 0
                else:
                    i = temp.index[-1]  # 有数据，就说明是中间有夜盘时间不对齐的问题，用前一天的atr来代替

                atr = self.tech[instrument]['atr'][i] * 10  # 测试
                allAtr[instrument] = atr
                # 对于这种取后一天的atr来假装，避免后面取atr 之前开仓所以有持仓，但此时没有atr的报错
                continue
            i = self.dictOfDateDf[instrument][self.dictOfDateDf[instrument]['Date Time'] == barTime].index[0]
            atr = self.tech[instrument]['atr'][i]  # * 50#测试
            allAtr[instrument] = atr

            if np.isnan(atr):  # 为nan说明数据还不够，不做计算。
                continue

            #  找到这个时间在df中的位置
            quantity = self.getQuantity(instrument, atr)
            long = self.tech[instrument]['long'][i - 1:i + 1]  # 取出到此时的最后两个
            short = self.tech[instrument]['short'][i - 1:i + 1]
            # t2 = time.time()
            # print(t2 - t1)
            # 开仓。
            if long[-1] < short[-1] and long[-2] > short[-2] and postion.get(instrument, 0) == 0:  # 当期上界变高表示创新高，新低同理
                ret = self.getBroker().createMarketOrder(Action.BUY, instrument, quantity)
                if instrument in postion:
                    ret1 = self.getBroker().createMarketOrder(Action.BUY_TO_COVER, instrument,
                                                              postion[instrument])
                    orders.append(ret1)
                self.openPriceAndATR[instrument] = [bars.getBar(instrument).getClose(), atr]  # 默认以收盘价开仓
                print('long')
                print(bars.getDateTime())
                print(instrument)
                orders.append(ret)

            elif short[-1] < long[-1] and long[-2] < short[-2] and postion.get(instrument, 0) == 0:
                ret = self.getBroker().createMarketOrder(Action.SELL_SHORT, instrument, quantity)
                if instrument in postion:
                    ret1 = self.getBroker().createMarketOrder(Action.SELL, instrument,
                                                              postion[instrument])
                    orders.append(ret1)
                self.openPriceAndATR[instrument] = [bars.getBar(instrument).getClose(), atr]  # 默认以收盘价开仓
                print('short')
                print(bars.getDateTime())
                print(instrument)
                orders.append(ret)


            # 加仓 或止损
            elif instrument in self.openPriceAndATR:  # 表示已有持仓
                if postion.get(instrument, 0) > 0:  # 持有多仓
                    if self.openPriceAndATR[instrument][0] + 0.5 * atr < bars.getBar(instrument).getClose():
                        # 且价格超过了上次开仓价加0.5atr在加仓
                        ret = self.getBroker().createMarketOrder(Action.BUY, instrument,
                                                                 quantity)
                        self.openPriceAndATR[instrument][0] = bars.getBar(instrument).getClose()
                        print('add long')
                        print(bars.getDateTime())
                        print(instrument)
                        orders.append(ret)

                    elif self.openPriceAndATR[instrument][0] - 0.5 * atr > bars.getBar(instrument).getClose():
                        # 且价格小于了上次开仓价减0.5atr，则止损
                        ret = self.getBroker().createMarketOrder(Action.SELL, instrument,
                                                                 abs(postion[instrument]))
                        self.openPriceAndATR.pop(instrument)
                        print('stop long')
                        print(bars.getDateTime())
                        print(instrument)
                        orders.append(ret)

                elif postion.get(instrument, 0) < 0:  # 持有空仓
                    if self.openPriceAndATR[instrument][0] - 0.5 * atr > bars.getBar(instrument).getClose():
                        # 且价格超过了开仓价加0.5atr 则加空
                        ret = self.getBroker().createMarketOrder(Action.SELL_SHORT, instrument,
                                                                 quantity)
                        self.openPriceAndATR[instrument][0] = bars.getBar(instrument).getClose()
                        print('add short')
                        print(bars.getDateTime())
                        print(instrument)
                        orders.append(ret)

                    elif self.openPriceAndATR[instrument][0] + 0.5 * atr < bars.getBar(instrument).getClose():
                        # 且价格大于了上次开仓价加0.5atr，则止损
                        ret = self.getBroker().createMarketOrder(Action.BUY_TO_COVER, instrument,
                                                                 abs(postion[instrument]))
                        self.openPriceAndATR.pop(instrument)
                        print('stop short')
                        print(bars.getDateTime())
                        print(instrument)
                        orders.append(ret)
            # t3 = time.time()
            # print(t3 - t2)

        # t3 = time.time()
        allPos = 0
        for instrument in postion:
            allPos += round(postion[instrument] / self.getQuantity(instrument, allAtr[instrument]))
            # 看某个品种有多少个单位的持仓，按照现在的atr来计算
        if allPos >= 10:
            open_mark = False  # 达到10个单位，不再开仓
        else:
            open_mark = True

        for item in orders:
            item.setGoodTillCanceled(True)
            item.setAllOrNone(True)
            action = item.getAction()
            if action == Action.SELL or action == Action.BUY_TO_COVER:  # 平仓的都可以
                self.getBroker().submitOrder(item)
            else:  # 开仓的情况
                if open_mark:
                    ins = item.getInstrument()
                    exist = round(postion.get(ins, 0) / self.getQuantity(ins, allAtr[ins]))
                    if exist <= 3:  # 如果单个品种小于3个单位的持仓，就可以开
                        self.getBroker().submitOrder(item)
                        allPos += 1
                    if allPos >= 10:
                        open_mark = False
        # self.i += 1  # 自增以移向下一个计数指标的值
        t4 = time.time()
        # print(t4 - t3)

    def getQuantity(self, instrument, atr):
        """
        计算此时可以开多少张
        :return:
        """

        quantity = self.equity
        # quantity = 1000000 测试

        KQFileName = self.context.categoryToFile[instrument]

        KQmultiplier = \
            self.generalTickInfo.loc[self.generalTickInfo['index_name'] == KQFileName, 'contract_multiplier'].iloc[0]

        # res = int(quantity / atr / 100 / KQmultiplier)  # 向下取整
        res = int(quantity / atr / 100 / 20)  # 由于目前回测系统没有考虑合约乘数，不需要除以合约乘数

        if res:
            return res
        else:
            return 1  # 至少开1手
        # 账户的1%的权益，除去atr值，再除去合约乘数，即得张数。表示一个atr的标准波动让账户的权益变动1%


class RandomordersStratey(YhlzStreategy):
    """随机发单的策略"""

    def __init__(self, feed, instrument, context, dictOfDataDf):
        super(RandomordersStratey, self).__init__(feed)
        self.__instrument = instrument

    def onBars(self, bars):
        pos = self.getBroker().getPositions()
        if True:
            # if pos.shape[0] == 1:
            num = randint(0, 1)
            if num == 0:
                direction = Action.BUY
            elif num == 1:
                direction = Action.SELL_SHORT

            logger.debug('买卖方向是：' + str(direction))
            # elif num == 2:
            #     direction = Action.SELL
            # elif num == 3:
            #     direction = Action.BUY_TO_COVER

            a = self.getBroker().createLimitOrder(direction, self.transInstrument(self.__instrument), 10)
            if a.virDirection == Action.BUY:
                a.price -= 10
            else:
                a.price += 10
            self.getBroker().submitOrder(a)
        else:
            rand = choice([0, 1])
            if rand:
                return
            for i in range(pos.shape[0]):
                if pos.loc[i, 'account'] != 1:
                    if pos.loc[i, 'direction'] == 'BUY':
                        direction = Action.SELL
                        a = self.getBroker().createMarketOrder(direction, pos.loc[i, 'contract'], pos.loc[i, 'volume'])
                        self.getBroker().submitOrder(a)
                    elif pos.loc[i, 'direction'] == 'SELL':
                        direction = Action.BUY_TO_COVER
                        a = self.getBroker().createMarketOrder(direction, pos.loc[i, 'contract'], pos.loc[i, 'volume'])
                        self.getBroker().submitOrder(a)
                    break

    def transInstrument(self, instrument):
        # return instrument
        return 'SHFE.rb2110'


class BP_GAStg(YhlzStreategy):

    def __init__(self, feed, instruments, context, data, learnRate=0.2, parameter=None):
        """

        :param feed:
        :param instruments:
        :param context:
        :param data:
        :param learnRate:
        :param parameter: 传给神经网络的初始参数，默认不需要传，但是参数优化的时候需要，类型是四个矩阵的list，大小分别是两个网络的【输入长度，隐层长度】与【隐层长度，输出长度】。
        """
        super().__init__(feed)
        self.feed = feed
        # 处理行情数据的神经网络，将最近1000根bar的历史行情输入，输出对行情是应该多还是空的一个判断。
        self.neroForData = neuralNetworkForBackTest(100 * len(instruments), 66 * len(instruments), len(instruments),
                                                    learnRate)
        # 处理已有持仓与行情信号的网络，根据已有的持仓和行情网络计算的结果作为输入，输出最终，应该有的持仓。
        self.neroForPos = neuralNetworkForBackTest(2 * len(instruments), int(4 * len(instruments) / 3),
                                                   len(instruments), learnRate)
        self.instruments = instruments
        self.pos = self.getBroker().getPositions()

        if parameter:
            self.neroForData.setPara(parameter[0], parameter[1])
            self.neroForPos.setPara(parameter[2], parameter[3])

    def onBars(self, bars):
        input = np.zeros(shape=(100 * len(self.instruments), 1))
        posList = []
        orders = []
        equity = self.getBroker().getEquity()
        i = 0
        for instrument in self.instruments:
            temp = np.array(self.feed[instrument].getPriceDataSeries())
            maxValue = np.max(temp)
            minValue = np.min(temp)

            if len(temp) < 100:
                return
            input[i * 100:i * 100 + 100, 0] = ((temp - minValue) / (maxValue - minValue + 0.01) + 0.01) * 0.99
            i += 1
            posList.append(self.pos.get(instrument, 0) * temp[-1] / equity)

        # 查询此时对行情的判断
        res = self.neroForData.query(input.T)

        # 传入持仓网络
        res = self.neroForPos.query(res.T.tolist()[0] + posList)
        res = (res - 0.5) * 2  # 神经网络的输入是s型函数，所以值域为[0, 1],现在变为[-1, 1]任意一个品种都可以满仓，但是满仓太多会导致巨亏，

        for i in range(len(self.instruments)):
            temp = pct2vol(self.instruments[i], res[i, 0], posList[i], equity,
                           self.feed[self.instruments[i]].getPriceDataSeries()[-1])
            if temp:
                if len(temp) == 2:
                    for item in temp:
                        orders.append(self.getBroker().createMarketOrder(item[0], item[1], abs(item[2])))
                elif len(temp) == 3:
                    orders.append(self.getBroker().createMarketOrder(temp[0], temp[1], abs(temp[2])))
                else:
                    raise Exception('pct2vol函数，返回值有问题！')

        for item in orders:
            self.getBroker().submitOrder(item)


    def getPara(self):
        return self.neroForData.getPara, self.neroForPos.getPara
